import glob
import os

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as tfs
from PIL import Image
from torch.utils.data import Dataset


class OCT(Dataset):
    """Dataset of multi-slice OCT scans indexed by a label table.

    Each row of the label table describes one scan: the first column holds
    the scan identifier and ``train_cols`` hold its label(s). A sample is
    built by loading up to ``NUM_SLICES`` grayscale JPEG slices for that
    identifier and stacking them as channels.

    Args:
        csv_path: Path to the label table. Files ending in ``.csv`` are read
            with ``pandas.read_csv``; anything else is read with
            ``pandas.read_excel``.
        image_root_path: Dataset root directory. Slices are searched for
            under ``<image_root_path>/<IMAGE_SUBDIR>``. Must not be empty.
        image_size: Target slice height in pixels; every slice is resized to
            ``image_size`` rows by ``3 * image_size`` columns.
        shuffle: When true, permute the rows of the table once at
            construction time.
        seed: Seed handed to ``np.random.seed`` before shuffling. Note that
            this reseeds NumPy's global random state.
        verbose: When true, print a short class-balance summary.
        train_cols: Label column name(s). Defaults to ``['target']``.
        mode: Stored as ``self.mode``; not otherwise used by this class.

    Raises:
        ValueError: If ``image_root_path`` is the empty string.
    """

    # Sub-directory of ``image_root_path`` that holds the per-scan folders.
    IMAGE_SUBDIR = 'StrokeForWu'
    # Number of slices (``OCT01`` .. ``OCT07``) expected for a complete scan.
    NUM_SLICES = 7

    def __init__(self,
                 csv_path,
                 image_root_path='',
                 image_size=32,
                 shuffle=True,
                 seed=123,
                 verbose=True,
                 train_cols=None,
                 mode='train'):
        super().__init__()

        # Validate with a real exception: ``assert`` disappears under ``-O``.
        if image_root_path == '':
            raise ValueError('You need to pass the correct location for the dataset!')
        if train_cols is None:
            train_cols = ['target']

        # Load the label table; honour a genuine CSV file as the parameter
        # name suggests, and keep Excel as the fallback for everything else.
        if str(csv_path).lower().endswith('.csv'):
            self.df = pd.read_csv(csv_path)
        else:
            self.df = pd.read_excel(csv_path)
        self._num_images = len(self.df)

        # Shuffle the rows once, reproducibly.
        if shuffle:
            data_index = list(range(self._num_images))
            np.random.seed(seed)
            np.random.shuffle(data_index)
            self.df = self.df.iloc[data_index]

        self.select_cols = ['target']  # this var determines the number of classes
        self.value_counts_dict = self.df[self.select_cols[0]].value_counts().to_dict()

        self.mode = mode
        self.image_size = image_size
        self.image_root_path = image_root_path

        self._images_list = self.df.iloc[:, 0].values.tolist()
        self._labels_list = self.df[train_cols].values.tolist()

        # Always compute the ratio so ``imbalance_ratio`` works regardless of
        # ``verbose``; a class that never occurs counts as zero samples.
        positives = self.value_counts_dict.get(1, 0)
        negatives = self.value_counts_dict.get(0, 0)
        labelled = positives + negatives
        self.imratio = positives / labelled if labelled else 0.0

        if verbose:
            print ('-'*30)
            print('Found %s images in total, %s positive images, %s negative images'%(self._num_images, positives, negatives))
            print ('%s: imbalance ratio is %.4f'%(self.select_cols[0], self.imratio))
            print ('-'*30)

    @property
    def class_counts(self):
        """Mapping from label value to its number of rows in the table."""
        return self.value_counts_dict

    @property
    def imbalance_ratio(self):
        """Fraction of rows labelled 1 among rows labelled 0 or 1."""
        return self.imratio

    @property
    def num_classes(self):
        """Number of entries in ``select_cols``."""
        return len(self.select_cols)

    @property
    def data_size(self):
        """Number of rows in the label table."""
        return self._num_images

    def image_augmentation(self, image):
        """Apply a random affine jitter to ``image`` and return the result.

        The transform rotates by up to 15 degrees either way, translates by
        up to 5% per axis, scales between 0.95 and 1.05 and fills uncovered
        pixels with 128. Not called by ``__getitem__``.
        """
        # Original author's note: newer releases renamed ``fillcolor`` to ``fill``.
        jitter = tfs.Compose([
            tfs.RandomAffine(degrees=(-15, 15), translate=(0.05, 0.05),
                             scale=(0.95, 1.05), fill=128),
        ])
        return jitter(image)

    def __len__(self):
        return self._num_images

    def __getitem__(self, idx):
        """Load the scan at position ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is a float32 array of shape
            ``(1, C, image_size, 3 * image_size)`` scaled to ``[0, 1]`` and
            ``C`` is the number of slices that could be loaded, and ``label``
            is a flat float32 array. Loading stops at the first missing or
            unreadable slice, so ``C < NUM_SLICES`` marks an incomplete scan.
            When not even the first slice is available the image is a
            ``(1, 1, 1, 1)`` zero placeholder.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        uid = self._images_list[idx]
        scan_root = os.path.join(self.image_root_path, self.IMAGE_SUBDIR)

        slices = []
        for slice_no in range(1, self.NUM_SLICES + 1):
            pattern = os.path.join(scan_root,
                                   str(uid) + "_*Volume",
                                   str(uid) + "_*OCT0" + str(slice_no) + ".jpg")
            path = glob.glob(pattern)
            if not path:
                print('uid:' + str(uid))
                print(path)
                break

            frame = cv2.imread(path[0], cv2.IMREAD_GRAYSCALE)
            if frame is None:
                # imread signals failure by returning None; treat the slice
                # like a missing one instead of crashing inside resize.
                print('uid:' + str(uid))
                print('unreadable image: ' + str(path[0]))
                break

            # dsize is (width, height): slices are three times wider than tall.
            frame = cv2.resize(frame,
                               dsize=(self.image_size * 3, self.image_size),
                               interpolation=cv2.INTER_LINEAR)
            slices.append(frame / 255.0)

        if slices:
            volume = np.stack(slices, axis=0)  # (C, H, W)
        else:
            volume = np.zeros([1, 1, 1])       # placeholder for a scan with no slices

        image = volume.astype(np.float32)[np.newaxis]
        label = np.array(self._labels_list[idx]).reshape(-1).astype(np.float32)
        return image, label

    def get_labels(self):
        """Return every label in the table as one flat array."""
        return np.array(self._labels_list).reshape(-1)

if __name__ == '__main__':
    # Convert the training split from per-slice JPEGs into two .npy arrays:
    # OCT_X holds the stacked image volumes, OCT_Y the matching labels.
    root = '/home/dixzhu/Kailuan_stroke/'
    traindSet = OCT(csv_path=root+'train.xlsx', image_root_path=root, image_size=256, mode='train')

    expected_slices = 7
    volumes = []
    labels = []
    for idx, (train_data, train_label) in enumerate(traindSet):
        # Scans with missing slices come back with fewer channels; skip them.
        if train_data.shape[1] < expected_slices:
            continue
        print(idx)
        # Collect first and concatenate once: growing the arrays inside the
        # loop is quadratic and broke when sample 0 itself was skipped.
        volumes.append(train_data)
        labels.append(train_label)

    if not volumes:
        raise SystemExit('No complete scans found; nothing to save.')

    trX = np.concatenate(volumes, axis=0)
    trY = np.concatenate(labels, axis=0)
    np.save('/home/dixzhu/data/OCT_X',trX)
    np.save('/home/dixzhu/data/OCT_Y',trY)
    print(trX.shape)
    print(trY.shape)
    
